// Package sso wraps SSO functionality to be used by Fleet's service layer.
// It uses https://github.com/crewjam/saml for SAML parsing and validation.
//
// Initiate SSO:
//   - Fleet generates a random session ID and a SAML AuthnRequest with a random Request ID.
//   - Fleet stores the session in Redis with the session ID as key and "Request ID" + "Original URL" + configured Metadata as value.
//   - Fleet returns a URL that redirects the user to the IdP with the AuthnRequest, along with the session ID as an HTTP cookie.
//
// Callback SSO:
//   - Fleet receives SAMLResponse in the request.
//   - Fleet reads the session ID from an HTTP cookie and loads the session from Redis.
//   - Fleet uses the Request ID + Metadata loaded from Redis to verify the SAMLResponse.
//   - If verification succeeds, Fleet redirects the user to the Original URL loaded from Redis.
//
// IdP-initiated Callback SSO (if enabled by the admin):
//   - Fleet receives SAMLResponse without session ID or "Request ID".
//   - Fleet uses the configured metadata to verify the SAMLResponse.
//
// Note: we use an HTTP cookie for the session ID to prevent CSRF attacks (an outcome of a pentest).
package sso

import (
	"context"
	"net/url"

	"github.com/crewjam/saml"
	"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
	"github.com/fleetdm/fleet/v4/server/fleet"
)

// SAMLProviderFromConfiguredMetadata builds a SAML service provider whose IdP
// metadata comes from the SSO provider settings configured by the admin.
//
// entityID becomes the provider's EntityID, and acsURL is parsed into its
// Assertion Consumer Service URL. The provider uses
// saml.EmailAddressNameIDFormat as its AuthnNameIDFormat. No
// ValidateAudienceRestriction hook is set.
//
// A failure to obtain or parse the IdP metadata is returned as a wrapped
// *fleet.BadRequestError. A failure to parse acsURL is wrapped with context
// and returned.
func SAMLProviderFromConfiguredMetadata(
	ctx context.Context,
	entityID string,
	acsURL string,
	settings *fleet.SSOProviderSettings,
) (*saml.ServiceProvider, error) {
	idpMetadata, metadataErr := GetMetadata(settings)
	if metadataErr != nil {
		// Surface metadata problems as a client error: they stem from the
		// admin-supplied SSO configuration.
		badRequest := &fleet.BadRequestError{
			Message:     "failed to get and parse IdP metadata",
			InternalErr: metadataErr,
		}
		return nil, ctxerr.Wrap(ctx, badRequest)
	}

	acs, parseErr := url.Parse(acsURL)
	if parseErr != nil {
		return nil, ctxerr.Wrap(ctx, parseErr, "failed to parse ACS URL")
	}

	provider := &saml.ServiceProvider{
		EntityID:          entityID,
		AcsURL:            *acs,
		IDPMetadata:       idpMetadata,
		AuthnNameIDFormat: saml.EmailAddressNameIDFormat,
	}
	return provider, nil
}

// SAMLProviderFromSession builds a SAML service provider from the SSO session
// identified by sessionID. The session is looked up through
// sessionStore.Fullfill, and its stored IdP metadata is parsed to configure
// the provider.
//
// The returned provider uses entityID and acsURL. It checks each assertion's
// audience restriction against expectedAudiences via validateAudiences.
// Alongside the provider, the session's RequestID, OriginalURL and
// RequestData are returned so the caller can finish the login flow.
//
// On any failure, every non-error result is its zero value.
func SAMLProviderFromSession(
	ctx context.Context,
	sessionID string,
	sessionStore SessionStore,
	acsURL *url.URL,
	entityID string,
	expectedAudiences []string,
) (samlProvider *saml.ServiceProvider, requestID, originalURL string, ssoRequestData SSORequestData, err error) {
	session, fulfillErr := sessionStore.Fullfill(sessionID)
	if fulfillErr != nil {
		return nil, "", "", SSORequestData{}, ctxerr.Wrap(ctx, fulfillErr, "validate request in session")
	}

	// The metadata was captured when the session was created, so responses are
	// verified against the IdP configuration in effect at login start.
	idpMetadata, parseErr := ParseMetadata([]byte(session.Metadata))
	if parseErr != nil {
		return nil, "", "", SSORequestData{}, ctxerr.Wrap(ctx, parseErr, "failed to parse metadata")
	}

	checkAudience := func(assertion *saml.Assertion) error {
		return validateAudiences(assertion, expectedAudiences)
	}
	samlProvider = &saml.ServiceProvider{
		EntityID:                    entityID,
		AcsURL:                      *acsURL,
		IDPMetadata:                 idpMetadata,
		ValidateAudienceRestriction: checkAudience,
	}
	return samlProvider, session.RequestID, session.OriginalURL, session.RequestData, nil
}

// SAMLProviderFromSessionOrConfiguredMetadata creates a SAML provider that can
// validate SAML responses, choosing where the IdP metadata comes from:
//
//   - If sessionID is empty and IdP-initiated logins are enabled
//     (settings.EnableSSOIdPLogin), the provider is built from the configured
//     IdP metadata. requestID is returned empty and redirectURL is "/".
//   - Otherwise the provider is built from the SSO session identified by
//     sessionID (looked up via sessionStore.Fullfill). requestID and
//     redirectURL are taken from that session. An empty sessionID with
//     IdP-initiated logins disabled also takes this path, so the error in that
//     case comes from the session store.
//
// In both cases the provider uses settings.EntityID and acsURL, and uses
// redirectURL as its DefaultRedirectURI. It checks audiences against
// expectedAudiences via validateAudiences, and sets AllowIDPInitiated from
// settings.EnableSSOIdPLogin regardless of which branch was taken.
func SAMLProviderFromSessionOrConfiguredMetadata(
	ctx context.Context,
	sessionID string,
	sessionStore SessionStore,
	acsURL *url.URL,
	settings *fleet.SSOSettings,
	expectedAudiences []string,
) (samlProvider *saml.ServiceProvider, requestID string, redirectURL string, err error) {
	// Fleet-initiated logins carry the session ID in an HTTP cookie, so a
	// missing session ID indicates the login was started at the IdP.
	idpInitiated := sessionID == ""

	var entityDescriptor *saml.EntityDescriptor
	if settings.EnableSSOIdPLogin && idpInitiated {
		// IdP-initiated login, allowed only because the admin enabled it: there
		// is no session, so verify against the currently configured metadata.
		// Assign to the named result err rather than redeclaring (shadowing) it.
		entityDescriptor, err = GetMetadata(&settings.SSOProviderSettings)
		if err != nil {
			return nil, "", "", ctxerr.Wrap(ctx, err, "failed to parse metadata")
		}
		redirectURL = "/"
	} else {
		// A distinct name avoids shadowing the named result err in this scope.
		session, fulfillErr := sessionStore.Fullfill(sessionID)
		if fulfillErr != nil {
			return nil, "", "", ctxerr.Wrap(ctx, fulfillErr, "validate request in session")
		}
		entityDescriptor, err = ParseMetadata([]byte(session.Metadata))
		if err != nil {
			return nil, "", "", ctxerr.Wrap(ctx, err, "failed to parse metadata")
		}
		redirectURL = session.OriginalURL
		requestID = session.RequestID
	}

	return &saml.ServiceProvider{
		EntityID:           settings.EntityID,
		AcsURL:             *acsURL,
		DefaultRedirectURI: redirectURL,
		IDPMetadata:        entityDescriptor,
		ValidateAudienceRestriction: func(assertion *saml.Assertion) error {
			return validateAudiences(assertion, expectedAudiences)
		},
		// NOTE(review): this follows the admin setting even on the
		// session-based path; confirm callers bind the response to requestID
		// themselves when a session exists.
		AllowIDPInitiated: settings.EnableSSOIdPLogin,
	}, requestID, redirectURL, nil
}
